import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P

def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid"):
    """Create a bias-free ``nn.Conv2d`` whose weights use ``weight_variable()``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels (filters).
        kernel_size: Convolution kernel size, forwarded unchanged.
        stride: Convolution stride (default 1).
        padding: Explicit padding amount, forwarded unchanged (default 0).
        pad_mode: MindSpore padding mode string (default ``"valid"``).

    Returns:
        A configured ``nn.Conv2d`` cell with ``has_bias=False``.
    """
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        pad_mode=pad_mode,
        has_bias=False,
        weight_init=weight_variable(),
    )

def fc_with_initialize(input_channels, out_channels):
    """Create an ``nn.Dense`` layer whose weight and bias both use ``weight_variable()``.

    Each of weight and bias gets its own initializer instance.

    Args:
        input_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.

    Returns:
        A configured ``nn.Dense`` cell.
    """
    weight_init = weight_variable()
    bias_init = weight_variable()
    return nn.Dense(
        input_channels,
        out_channels,
        weight_init=weight_init,
        bias_init=bias_init,
    )

def weight_variable(sigma=0.02):
    """Return a ``TruncatedNormal`` initializer used for layer weights and biases.

    Args:
        sigma: Value forwarded to ``TruncatedNormal``. The default of 0.02 is the
            value previously hard-coded here, so existing zero-argument calls
            behave exactly as before.

    Returns:
        A new ``TruncatedNormal`` initializer instance on every call.
    """
    return TruncatedNormal(sigma)

class AlexNet(nn.Cell):
    """AlexNet: five convolution layers followed by three fully-connected layers.

    Args:
        num_classes: Output size of the final Dense layer (default 10).
        channel: Number of input channels for the first convolution (default 3).
    """
    def __init__(self, num_classes=10, channel=3):
        # Initialize the Cell base exactly once. The original called this
        # twice in a row, which only re-ran base initialization for no benefit.
        super(AlexNet, self).__init__()
        self.conv1 = conv(channel, 96, 11, stride=4)
        self.conv2 = conv(96, 256, 5, pad_mode="same")
        self.conv3 = conv(256, 384, 3, pad_mode="same")
        self.conv4 = conv(384, 384, 3, pad_mode="same")
        self.conv5 = conv(384, 256, 3, pad_mode="same")
        # One ReLU and one pooling cell are shared by every stage below.
        self.relu = nn.ReLU()
        self.max_pool2d = P.MaxPool(ksize=3, strides=2)
        self.flatten = nn.Flatten()
        # fc1 expects 256 channels at 6x6 spatial size after the last pool.
        # With conv1 'valid' and the pools presumably 'valid' (P.MaxPool default
        # — TODO confirm), a 227x227 input yields 6x6; 224x224 would yield 5x5
        # and fail to match. NOTE(review): confirm the expected input size.
        self.fc1 = fc_with_initialize(6 * 6 * 256, 4096)
        self.fc2 = fc_with_initialize(4096, 4096)
        self.fc3 = fc_with_initialize(4096, num_classes)

    def construct(self, x):
        """Run the forward pass and return raw class scores (no softmax applied).

        Args:
            x: Input image batch; presumably NCHW with ``channel`` channels — confirm.

        Returns:
            Output of ``fc3``, one value per class.
        """
        # Stage 1: conv -> relu -> pool
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Stage 2: conv -> relu -> pool
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Stage 3: three stacked convs; a single pool comes after the last one
        x = self.conv3(x)
        x = self.relu(x)
        x = self.conv4(x)
        x = self.relu(x)
        x = self.conv5(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Classifier head
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x